#!/usr/bin/python3
from servicebase import ServiceBase
import aiohttp
import json
import traceback
import time
from aiohttp.client import ClientSession
from aiohttp import http_exceptions
import asyncio
import logging
import utils
import random
from utils import Module
from utils import Node
from utils import AuthverManager
from utils import AEScoder
from utils import Utils
# Error codes returned (as negative ints) by RCMDMaster methods.  The two
# WORKER_RANGE codes are not referenced in the code shown here; presumably they
# are shared with other services.
ERR_WORKER_RANGE_OVERLAP = -10001
ERR_WORKER_RANGE_LEAK = -10002
ERR_NODE_TIMEOUT_OR_DEAD = -10003
ERR_MODULE_NOT_EXIST = -10004
ERR_MODULE_NOT_AVAILABLE = -10005
ERR_INVALID_REQUEST = -10006

# Node health states.  GREEN and YELLOW nodes are eligible for DNS answers,
# RED marks a node that kept failing checks, OFFLINE_UPDATING a node that
# confirmed it is going offline for an update.
(NODE_STATUS_GREEN,
 NODE_STATUS_YELLOW,
 NODE_STATUS_RED,
 NODE_STATUS_OFFLINE_UPDATING) = range(4)

class RCMDMaster(ServiceBase):

    def __init__(self):
        """Initialise empty runtime state, then run the ServiceBase constructor."""
        # Timing / versioning state.
        self._timestamp = 0          # version stamp put into update packets for nodes
        self._ping_gap = 5
        self._last_ping = 0
        self._last_check_nodes = 0   # main_loop starts a check round when this is > 2 s old
        # Configuration and the lookup tables load_config() derives from it.
        self._config = {}
        self._modules = {}
        self._app_access_point = {}
        self._kcs_config_point = {}
        # Runtime bookkeeping.
        self._pools = {}             # "ip:port" -> cached aiohttp ClientSession (see ping)
        self._ips_stat = {}          # "ip:port" -> DNS hits since the last print_stat
        # Base-class setup runs last, after every attribute above exists.
        ServiceBase.__init__(self)
    
    def load_config(self):
        """Build the module/node tables from ``self._config``.

        ``self._config['module']`` maps module names to their config.  Two names
        are special:

        * ``frontier``: a mapping of sub-module name -> config plus a ``default``
          entry supplying fallback ``etc``/``require``.  Each sub-module is
          registered as ``frontier-<name>`` and reachable through the app
          access-point table under ``<name>`` and any ``also_access_for``
          aliases.  All frontier nodes are also collected into an aggregate
          ``frontier`` module that is never pinged.
        * ``kcs``: the same sub-module layout (no default), registered as
          ``kcs-<name>`` in the KCS access-point table.

        Every other entry is loaded as a single module under its own name.
        The module table and both access-point tables are rebuilt from scratch
        and swapped in at the end, so aliases of modules that disappeared from
        the config do not linger and point at missing modules.

        Returns:
            0.  A malformed config raises (e.g. KeyError / IndexError).
        """
        modules = {}
        app_access_point = {}
        kcs_config_point = {}
        frontier_all_nodes = []
        for module, conf in self._config['module'].items():
            if module == 'frontier':
                default = conf['default']
                for m, sub_conf in conf.items():
                    if m == 'default':
                        continue
                    module_name = module + "-" + m
                    self._load_module(modules, module_name, sub_conf, default)
                    frontier_all_nodes.extend(modules[module_name].nodes)
                    self._register_access_points(app_access_point, m, module_name, sub_conf)
            elif module == "kcs":
                logging.error("load kcs config")
                for m, sub_conf in conf.items():
                    module_name = module + '-' + m
                    logging.error("kcs: " + module_name)
                    self._load_module(modules, module_name, sub_conf, {})
                    self._register_access_points(kcs_config_point, m, module_name, sub_conf)
            else:
                self._load_module(modules, module, conf, {})

        # Aggregate of every frontier node: available for lookups, never pinged.
        modules['frontier'] = Module()
        modules['frontier'].nodes = frontier_all_nodes
        modules['frontier'].ping = False

        self._modules = modules
        self._app_access_point = app_access_point
        self._kcs_config_point = kcs_config_point
        self._etc = self._config['etc']
        self._timestamp = time.time()
        return 0

    @staticmethod
    def _parse_nodes(entries):
        """Turn ``[ip, port, etc(, uri)]`` config entries into a list of Node objects."""
        nodes = []
        for access in entries:
            node = Node()
            node.ip = access[0]
            node.port = access[1]
            node.etc = access[2]
            if len(access) > 3:
                node.uri = access[3]
            nodes.append(node)
        return nodes

    @classmethod
    def _load_module(cls, modules, name, conf, default):
        """Build one Module from ``conf`` (falling back to ``default``) into ``modules[name]``."""
        module = Module()
        module.etc = conf['etc'] if 'etc' in conf else default['etc']
        module.ping = conf.get('ping', True)
        # Test for 'require' itself: the default applies only when the module
        # does not declare its own dependency list.
        module.require = conf['require'] if 'require' in conf else default['require']
        logging.error("load etc for module...%s" % (name))
        module.nodes = cls._parse_nodes(conf['node'])
        module.https_nodes = cls._parse_nodes(conf.get('https_node', []))
        module.last = time.time()
        modules[name] = module

    @staticmethod
    def _register_access_points(table, name, module_name, conf):
        """Map ``name`` and its ``also_access_for`` aliases to ``module_name`` in ``table``.

        ``also_access_for`` (kept for backward compatibility) is a comma-separated
        string or a list; blanks around aliases are ignored, empty aliases skipped.
        """
        table[name] = module_name
        aliases = conf.get('also_access_for', [])
        if isinstance(aliases, str):
            aliases = aliases.split(",")
        for alias in aliases:
            alias = alias.strip()
            if alias:
                table[alias] = module_name

    def build_update_for_module(self, module, update_flag):
        """Build the sync packet that check_node() sends to nodes of ``module``.

        The packet lists, for every module named in ``module``'s ``require``,
        its nodes as ``[ip, port, status, etc]`` (OFFLINE_UPDATING nodes are
        left out), plus ``module``'s own ``etc``, the given ``update_flag`` and
        the current config ``timestamp``.

        Returns:
            The packet dict on success.  NOTE: for an unknown ``module`` it
            instead returns the tuple ``(ERR_MODULE_NOT_EXIST, None)``; a
            required module missing from ``self._modules`` raises KeyError.
        """
        if module not in self._modules:
            return ERR_MODULE_NOT_EXIST, None

        target = self._modules[module]
        required = {
            dep: [
                [node.ip, node.port, node.status, node.etc]
                for node in self._modules[dep].nodes
                if node.status != NODE_STATUS_OFFLINE_UPDATING
            ]
            for dep in target.require
        }
        return {
            'module': required,
            'etc': target.etc,
            'update_flag': update_flag,
            'timestamp': self._timestamp,
        }

    async def ping(self, ip, port, data):
        """POST a probe/sync request to ``<ip>:<port>/ping?seq=0``.

        Args:
            ip: host or address; if it already starts with ``http://`` or
                ``https://`` it is used as-is, otherwise ``http://`` is prepended.
            port: TCP port.
            data: payload, sent as JSON ``{'data': data}`` (None for a plain probe).

        Returns:
            ``(0, decoded JSON body)`` on HTTP 200,
            ``(-status, None)`` for any other HTTP status,
            ``(ERR_NODE_TIMEOUT_OR_DEAD, None)`` on connection errors, timeouts
            or any other exception (e.g. a non-JSON body).

        One ClientSession per ``ip:port`` is cached in ``self._pools``.
        """
        key = "%s:%d" % (ip, port)
        status = None
        # Only a real scheme prefix counts: a host name that merely contains
        # "http" (e.g. "myhttpd.local") still needs "http://" prepended.
        base = ip if ip.startswith(("http://", "https://")) else "http://" + ip
        url = "%s:%d/ping?seq=0" % (base, port)
        try:
            session = self._pools.get(key)
            if session is None:
                # NOTE(review): aiohttp 3.x deprecates conn_timeout / numeric
                # timeout= in favour of aiohttp.ClientTimeout; switch once the
                # deployed aiohttp version is confirmed.
                session = self._pools[key] = ClientSession(conn_timeout=2)
            async with session.post(url, timeout=2, json={'data': data}) as resp:
                status = resp.status
                if resp.status != 200:
                    logging.error("%s ping error,status=%d" % (key, resp.status))
                    return -resp.status, None
                return 0, await resp.json()
        except (asyncio.TimeoutError, aiohttp.ServerTimeoutError,
                aiohttp.ClientConnectionError, aiohttp.ClientConnectorError) as e:
            # asyncio.TimeoutError is what the request-level timeout raises.
            logging.error("%s dead! retcode=%d,err=%s" % (key, ERR_NODE_TIMEOUT_OR_DEAD, e))
            return ERR_NODE_TIMEOUT_OR_DEAD, None
        except Exception as e1:
            logging.error("%s weired! retcode=%d,err=%s,status=%s" % (key, ERR_NODE_TIMEOUT_OR_DEAD, e1, status))
            logging.error(e1)
            return ERR_NODE_TIMEOUT_OR_DEAD, None

    async def check_node(self, module, node, update_flag=0, sync_flag=1):
        """Probe one node, resynchronise it if needed, and update its health counters.

        Args:
            module: name of the module (key of ``self._modules``) owning ``node``.
            node: the Node to check; its ``fail``, ``status`` and ``last`` are updated.
            update_flag: when non-zero (and ``sync_flag`` is set) the plain probe
                is skipped and the update packet, built with this flag, is sent
                directly.
            sync_flag: when truthy, a node whose reported ``timestamp`` is missing
                or older than ``self._timestamp`` is sent the module's update
                packet.  When falsy only a plain probe is made.

        Health bookkeeping:
            * success halves ``node.fail``; when it reaches 0 the node becomes
              GREEN and ``self._timestamp`` is bumped (synced nodes then receive
              a fresh update packet on their next check);
            * failure (no answer, HTTP error, embedded error code, failed resync)
              increments ``node.fail``; above 5 the node becomes RED (unless it is
              OFFLINE_UPDATING) and ``self._timestamp`` is bumped;
            * a node answering the sync with ``confirm_offline_update`` becomes
              OFFLINE_UPDATING with ``fail`` = 500, so it needs a long run of
              successful checks before it turns GREEN again.

        Returns:
            0 on success, otherwise the error code from ping() or from the node.
        """
        now = time.time()
        retcode = 0
        data = {}
        # Unsynced nodes always get a plain probe; otherwise update_flag would
        # mean "do nothing" for them.
        if update_flag == 0 or not sync_flag:
            retcode, resp = await self.ping(node.ip, node.port, None)
            retcode, data = self._split_ping_response(retcode, resp)

        if retcode == 0 and sync_flag:
            if 'timestamp' not in data or data['timestamp'] < self._timestamp:
                # resync etc
                ping_data = self.build_update_for_module(module, update_flag)
                retcode, resp = await self.ping(node.ip, node.port, ping_data)
                retcode, data = self._split_ping_response(retcode, resp)
                if retcode == 0:
                    if data.get('confirm_offline_update'):
                        node.status = NODE_STATUS_OFFLINE_UPDATING
                        self._last_check_nodes = 0  # force an immediate new check round
                        node.fail = 500             # long back-off before GREEN again
                        self._timestamp = now
                        logging.error("%s:%d[%s],offline for update!status=%d, fail=%d" % (node.ip, node.port, module, node.status, node.fail))
                    logging.error("%s:%d sync module info,module=%s" % (node.ip, node.port, module))
                else:
                    logging.error("%s:%d sync module info failed,module=%s, retcode=%d" % (node.ip, node.port, module, retcode))

        if retcode == 0:
            if node.fail > 0:
                node.fail //= 2
                logging.error("%s:%d[%s] back online, delay=%d" % (node.ip, node.port, module, node.fail))
                if node.fail == 0:
                    node.status = NODE_STATUS_GREEN
                    self._modules[module].last = now
                    self._timestamp = now
        else:
            node.fail += 1
            if node.fail > 5:
                if node.status != NODE_STATUS_OFFLINE_UPDATING:
                    node.status = NODE_STATUS_RED
                self._modules[module].last = now
                self._timestamp = now
        node.last = now

        return retcode

    @staticmethod
    def _split_ping_response(retcode, resp):
        """Return ``(retcode, data)`` from a ping() result.

        An embedded ``resp['error']`` overrides ``retcode`` with its ``code``;
        ``data`` is ``resp['data']`` when that is a dict, else an empty dict.
        """
        if not isinstance(resp, dict) or not resp:
            return retcode, {}
        error = resp.get('error')
        if error:
            retcode = error['code']
        data = resp.get('data')
        return retcode, data if isinstance(data, dict) else {}

    async def check_nodes(self, loop):
        """Schedule one check_node() task per pingable node of every pinged module.

        Plain ``nodes`` are checked with sync enabled; ``https_nodes`` only get a
        health probe.  Among a module's plain nodes, one random healthy node per
        round is checked with ``update_flag=1``, but only while fewer than half
        of those nodes are RED or OFFLINE_UPDATING.  Nodes whose ``etc`` has
        ``"ping": 0`` are skipped.  The tasks are created on ``loop`` and not
        awaited here.
        """
        for name, module in self._modules.items():
            if not module.ping:
                continue
            for nodes, sync_flag in ((module.nodes, 1), (module.https_nodes, 0)):
                # update_flag is only meaningful for synced nodes, and the health
                # ratio must be computed per list, not accumulated across lists.
                offline_update_node = (
                    self._pick_offline_update_node(nodes) if sync_flag else None)
                for node in nodes:
                    # Ping defaults to on: missing or 1 means ping, 0 means skip.
                    if not node.etc.get("ping", 1):
                        continue
                    update_flag = 1 if node == offline_update_node else 0
                    loop.create_task(self.check_node(name, node, update_flag, sync_flag))

    @staticmethod
    def _pick_offline_update_node(nodes):
        """Return a random healthy node of ``nodes``, or None when half or more are down."""
        up_nodes = [n for n in nodes
                    if n.status not in (NODE_STATUS_RED, NODE_STATUS_OFFLINE_UPDATING)]
        down_count = len(nodes) - len(up_nodes)
        if up_nodes and down_count / len(nodes) < 0.5:
            return random.choice(up_nodes)
        return None

    def select_ips(self, nodes, module_etc, app_channel, appid, *, ip_count=2):
        """Pick up to ``ip_count`` distinct healthy nodes by weighted random sampling.

        Args:
            nodes: candidate nodes; each has ``ip``, ``port``, ``status`` and an
                ``etc`` dict whose optional ``weight`` defaults to 1.
            module_etc: the module's ``etc`` dict.  Its optional
                ``app_channel_route`` maps plan names to dicts with ``appids``,
                ``app_channels`` and ``ips`` lists; when a plan lists both
                ``appid`` and ``app_channel``, only nodes whose ip is in that
                plan's ``ips`` are eligible (the last matching plan wins, and an
                empty ``ips`` imposes no limit).
            app_channel, appid: request identity used for route matching.
            ip_count: maximum number of nodes returned (default 2).

        Returns:
            List of selected nodes, possibly empty.  Only GREEN/YELLOW nodes with
            non-zero weight are eligible.  Each pick is proportional to the
            remaining weights and removes the node from the pool.
        """
        limit_ips = None
        for route_plan in (module_etc.get("app_channel_route") or {}).values():
            if appid in route_plan.get("appids", []) and app_channel in route_plan.get("app_channels", []):
                limit_ips = set(route_plan.get("ips", []))

        candidates = []  # (node, weight) pairs still in the pool
        total_weight = 0
        for node in nodes:
            # Per-node trace on every request: debug level, lazily formatted.
            logging.debug("node ip=%s, port=%s", node.ip, node.port)
            weight = node.etc.get("weight", 1)
            if node.status not in (NODE_STATUS_GREEN, NODE_STATUS_YELLOW) or weight == 0:
                continue
            if not limit_ips or node.ip in limit_ips:
                candidates.append((node, weight))
                total_weight += weight

        ips = []
        for _ in range(ip_count):
            # Draw a point in [1, total_weight] and walk the cumulative weights.
            choice = random.randint(1, total_weight) if total_weight > 1 else 1
            for idx, (node, weight) in enumerate(candidates):
                choice -= weight
                if choice <= 0:
                    ips.append(node)
                    total_weight -= weight
                    del candidates[idx]
                    break
            if not candidates or total_weight == 0:
                break
        return ips

    async def get_dns_v1(
        self, func, ac, device_platform, loc_time, latitude,
        longtitude, city, iid, device_id, udid, openudid,
        device_type, os_api, os_version, client_version):
        """Legacy DNS handler: pick plain (non-https) nodes of the module mapped to ``func``.

        Only ``func`` is used; the other parameters are part of the handler
        signature.

        Returns:
            ``(0, [{'type': 'ipport', 'ip': ..., 'port': ...}, ...])`` on success;
            ``(ERR_INVALID_REQUEST, None)`` when ``func`` has no access point;
            ``(ERR_MODULE_NOT_EXIST, None)`` when the mapped module is not loaded;
            ``(ERR_MODULE_NOT_AVAILABLE, None)`` when no eligible node was selected.
        """
        module_name = self._app_access_point.get(func)
        if not module_name:
            return ERR_INVALID_REQUEST, None
        module = self._modules.get(module_name)
        if module is None:
            return ERR_MODULE_NOT_EXIST, None

        ips = self.select_ips(module.nodes, module.etc, "", func)
        if not ips:
            return ERR_MODULE_NOT_AVAILABLE, None
        data = []
        for node in ips:
            self.stat_ip(node)
            data.append({'type': 'ipport', 'ip': node.ip, 'port': node.port})
        return 0, data

    async def get_dns(
            self, app_channel, appid, ac, device_platform, loc_time, latitude,
            longtitude, city, iid, device_id, uuid, openudid,
            device_type, os_api, os_version, client_version, dns_version):
        """DNS handler: return an auth token plus frontier and (optionally) KCS hosts.

        Returns:
            ``(0, data)``.  ``data`` holds ``token_info``, ``frontier_list``,
            ``applog_list`` (the same list as ``frontier_list``) and, when KCS
            lookup succeeds, ``kcs_list``.
            ``(ERR_INVALID_REQUEST, None)`` if token generation fails.
            ``(retcode, None)`` with get_frontier_list's code if no frontier
            host is available.  A KCS failure is logged but not fatal.
        """
        logging.info(f"enter func get_dns, appid={appid}")
        retcode, token_info = self.generate_token_info(app_channel, appid, ac, device_platform, loc_time, latitude,
                                                       longtitude, city, iid, device_id, uuid, openudid,
                                                       device_type, os_api, os_version, client_version, dns_version)
        if retcode != 0:
            # Log the code that actually failed, not the one returned to the caller.
            logging.error("generate_token_info fail, result=%d" % (retcode))
            return ERR_INVALID_REQUEST, None
        logging.info("generate_token_info success")

        retcode, frontier_list = self.get_frontier_list(appid, app_channel)
        if retcode != 0:
            logging.error("get_frontier_list fail, retcode=%d" % (retcode))
            return retcode, None
        logging.info("get_frontier_list success")
        data = {
            'token_info': token_info,
            'frontier_list': frontier_list,
            'applog_list': frontier_list,
        }

        logging.debug("after get frontier list")
        retcode, kcs_list = self.get_kcs_list(appid, app_channel)
        logging.info(f"get kcs list, retcode={retcode}")
        if retcode == 0:
            data['kcs_list'] = kcs_list
        else:
            logging.error("get kcs list fail, retcode=%d" % (retcode))
        return 0, data

    def get_frontier_list(self, appid, app_channel):
        """Pick frontier hosts (the module's https nodes) for ``appid``.

        Returns:
            ``(0, [{'type': 'protocal', 'host': 'ip:port'}, ...])`` on success;
            ``(ERR_INVALID_REQUEST, None)`` when ``appid`` has no access point;
            ``(ERR_MODULE_NOT_EXIST, None)`` when the mapped module is not loaded;
            ``(ERR_MODULE_NOT_AVAILABLE, None)`` when no eligible node was selected.
        """
        module_name = self._app_access_point.get(appid)
        if not module_name:
            return ERR_INVALID_REQUEST, None
        module = self._modules.get(module_name)
        if module is None:
            return ERR_MODULE_NOT_EXIST, None

        ips = self.select_ips(module.https_nodes, module.etc, app_channel, appid)
        if not ips:
            return ERR_MODULE_NOT_AVAILABLE, None
        frontier_list = []
        for ip in ips:
            self.stat_ip(ip)
            # 'protocal' (sic) is part of the response payload; keep as-is.
            frontier_list.append(
                {'type': 'protocal', 'host': ip.ip + ':' + str(ip.port)})
        return 0, frontier_list

    def get_kcs_list(self, appid, app_channel):
        """Pick KCS hosts (the KCS module's https nodes) for ``appid``.

        Each host is ``ip:port`` followed by the node's ``uri`` when one is
        configured.

        Returns:
            ``(0, [{'type': 'protocal', 'host': ...}, ...])`` on success;
            ``(ERR_INVALID_REQUEST, None)`` when ``appid`` has no KCS entry;
            ``(ERR_MODULE_NOT_EXIST, None)`` when the mapped module is not loaded;
            ``(ERR_MODULE_NOT_AVAILABLE, None)`` when no eligible node was selected.
        """
        module_name = self._kcs_config_point.get(appid)
        if not module_name:
            return ERR_INVALID_REQUEST, None
        module = self._modules.get(module_name)
        if module is None:
            return ERR_MODULE_NOT_EXIST, None

        nodes = self.select_ips(module.https_nodes, module.etc, app_channel, appid)
        if not nodes:
            return ERR_MODULE_NOT_AVAILABLE, None
        kcs_list = []
        for node in nodes:
            self.stat_ip(node)
            host = node.ip + ':' + str(node.port)
            # load_config assigns uri only for 4-element node entries, so read it
            # defensively; the uri is expected to start with "/".
            uri = getattr(node, 'uri', None)
            if uri:
                host += uri
            # 'protocal' (sic) is part of the response payload; keep as-is.
            kcs_list.append({'type': 'protocal', 'host': host})
        return 0, kcs_list

    def generate_token_info(self, app_channel, appid, ac, device_platform, loc_time, latitude,
                            longtitude, city, iid, device_id, uuid, openudid,
                            device_type, os_api, os_version, client_version, dns_version):
        """Create a fresh ticket/token key pair and the matching authorization string.

        The token (app channel, ``uuid`` as imei, appid, client version, both
        authver ids and the current Unix time) is JSON-encoded, encrypted with
        the token key and wrapped by Utils.generate_authorization.

        Returns:
            ``(0, {'ticket_key': ..., 'authorization': ...})``.
        """
        logging.info("generate_token_info func start")

        manager = AuthverManager()
        ticket_authver, ticket_key = manager.generate_authver_key()
        token_authver, token_key = manager.generate_authver_key()
        logging.info("authver, ticket=%s, token=%s" % (ticket_authver, token_authver))

        # Key order determines the JSON text that is encrypted below.
        token = {
            'app_channel': app_channel,
            'imei': uuid,
            'appid': appid,
            'client_ver': client_version,
            'ticket_authver': ticket_authver,
            'token_authver': token_authver,
            'timestamp': int(time.time()),
        }
        encrypted_token = AEScoder.encrypt(token_key, json.dumps(token))
        authorization = Utils.generate_authorization("0", ticket_authver, token_authver, encrypted_token)

        return 0, {"ticket_key": ticket_key, "authorization": authorization}


    def stat_ip(self, node):
        key = "%s:%d"%(node.ip,node.port)
        if key not in self._ips_stat:
            self._ips_stat[key] = 0
        self._ips_stat[key] += 1
    
    def print_stat(self):
        if self._ips_stat:
            logging.error("========== dns stat ==========")
            for key in self._ips_stat:
                logging.error("%20s:%d"%(key, self._ips_stat[key]))
                
        logging.error("========== module stat ==========")
        for module in self._modules:
            if not self._modules[module].ping:
                continue
            for node in self._modules[module].nodes:
                logging.error("%s %s:%d=%d %d"%(module, node.ip, node.port, node.status, node.fail))
        logging.error("")
        self._ips_stat = {}

    def init_service(self):
        """Register the DNS handlers and return what startup_webservice() gives for the configured port."""
        for handler in (self.get_dns_v1, self.get_dns):
            self.add_handler(handler)
        return self.startup_webservice(self._config['service']['port'])

    async def main_loop(self):
        """Run the housekeeping loop forever: periodic stats dump and node check rounds.

        Polls every 10 ms.  Stats are logged and reset once more than 60 s have
        passed since the last dump.  A check round is scheduled once more than
        2 s have passed since ``self._last_check_nodes`` (check_node() zeroes it
        to force an immediate round).  An exception in one iteration is logged
        and the loop keeps running; cancellation still propagates.
        """
        loop = asyncio.get_event_loop()
        last_print_stat = time.time()

        while True:
            now = time.time()
            try:
                if now - last_print_stat > 60:
                    # Stamp first so a failing print_stat is retried in 60 s,
                    # not on every 10 ms tick.
                    last_print_stat = now
                    self.print_stat()
                if now - self._last_check_nodes > 2:
                    self._last_check_nodes = now
                    await self.check_nodes(loop)
            except asyncio.CancelledError:
                raise
            except Exception:
                # One bad round must not silently stop all future health checks.
                logging.exception("main_loop iteration failed")
            await asyncio.sleep(0.01)

    def start(self, cfg):
        """Load the config, schedule the web service and main loop, then run the event loop forever.

        ``cfg`` is not read here; configuration comes from ``self._config``.
        """
        self.load_config()
        event_loop = asyncio.get_event_loop()
        for coro in (self.init_service(), self.main_loop()):
            event_loop.create_task(coro)
        event_loop.run_forever()

if __name__ == "__main__":
    app = RCMDMaster()
    app.main(__file__.split('/')[-1].split(".")[0])
